{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial: Population-Based Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### In this tutorial, we'll show you how to leverage Population-based Training.\n",
    "\n",
    "<img src=\"pbt.png\" alt=\"PBT\" width=\"600\"/>\n",
    "\n",
    "Tune is a scalable framework for model training and hyperparameter search with a focus on deep learning and deep reinforcement learning.\n",
    "\n",
    "* **Code**: https://github.com/ray-project/ray/tree/master/python/ray/tune \n",
    "* **Examples**: https://github.com/ray-project/ray/tree/master/python/ray/tune/examples\n",
    "* **Documentation**: http://ray.readthedocs.io/en/latest/tune.html\n",
    "* **Mailing List** https://groups.google.com/forum/#!forum/ray-dev"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## If you are running on Google Colab, uncomment below to install the necessary dependencies \n",
    "## before beginning the exercise.\n",
    "\n",
    "# print(\"Setting up colab environment\")\n",
    "# !pip uninstall -y -q pyarrow\n",
    "# !pip install -q https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev5-cp36-cp36m-manylinux1_x86_64.whl\n",
    "# !pip install -q ray[debug]\n",
    "\n",
    "# # A hack to force the runtime to restart, needed to include the above dependencies.\n",
    "# print(\"Done installing! Restarting via forced crash (this is not an issue).\")\n",
    "# import os\n",
    "# os._exit(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "try:\n",
    "    tf.get_logger().setLevel('INFO')\n",
    "except Exception as exc:\n",
    "    print(exc)\n",
    "import warnings\n",
    "warnings.simplefilter(\"ignore\")\n",
    "\n",
    "import os\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets\n",
    "from ray.tune.examples.mnist_pytorch import train, test, ConvNet, get_data_loaders\n",
    "\n",
    "import ray\n",
    "from ray import tune\n",
    "from ray.tune import track\n",
    "from ray.tune.schedulers import PopulationBasedTraining\n",
    "from ray.tune.utils import validate_save_restore\n",
    "\n",
    "%matplotlib inline\n",
    "import matplotlib.style as style\n",
    "import matplotlib.pyplot as plt\n",
    "style.use(\"ggplot\")\n",
    "\n",
    "datasets.MNIST(\"~/data\", train=True, download=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup Trainable\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To utilize the PopulationBasedTraining Scheduler, we will have to use Tune's more extensive Class-based API. \n",
    "\n",
    "This API will allow Tune to take intermediate actions such as checkpointing and changing the hyperparameters in the middle of training.\n",
    "\n",
    "``train()`` wraps ``_train()``.\n",
    "\n",
    "A call to ``train()`` on a trainable will execute one logical iteration of training. As a rule of thumb, the execution time of one train call should be large enough to avoid overheads (i.e. more than a few seconds), but short enough to report progress periodically (i.e. at most a few minutes).\n",
    "\n",
    "### Instructions:\n",
    "\n",
    "Add training code under ``_train`` as follows:\n",
    "\n",
    "```python\n",
    "    def _train(self):\n",
    "        train(self.model, self.optimizer, self.train_loader, device=self.device)\n",
    "        acc = test(self.model, self.test_loader, self.device)\n",
    "        return {\"mean_accuracy\": acc}\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PytorchTrainble(tune.Trainable):\n",
    "    def _setup(self, config):\n",
    "        self.device = torch.device(\"cpu\")\n",
    "        self.train_loader, self.test_loader = get_data_loaders()\n",
    "        self.model = ConvNet().to(self.device)\n",
    "        self.optimizer = optim.SGD(\n",
    "            self.model.parameters(),\n",
    "            lr=config.get(\"lr\", 0.01),\n",
    "            momentum=config.get(\"momentum\", 0.9))\n",
    "\n",
    "    def _train(self):\n",
    "        # TODO: Add training code here.\n",
    "        return {\"mean_accuracy\": acc}\n",
    "\n",
    "    def _save(self, checkpoint_dir):\n",
    "        checkpoint_path = os.path.join(checkpoint_dir, \"model.pth\")\n",
    "        torch.save(self.model.state_dict(), checkpoint_path)\n",
    "        return checkpoint_path\n",
    "\n",
    "    def _restore(self, checkpoint_path):\n",
    "        self.model.load_state_dict(torch.load(checkpoint_path))\n",
    "        \n",
    "    def reset_config(self, new_config):\n",
    "        del self.optimizer\n",
    "        self.optimizer = optim.SGD(\n",
    "            self.model.parameters(),\n",
    "            lr=new_config.get(\"lr\", 0.01),\n",
    "            momentum=new_config.get(\"momentum\", 0.9))\n",
    "        return True\n",
    "\n",
    "\n",
    "ray.shutdown()  # Restart Ray defensively in case the ray connection is lost. \n",
    "ray.init(log_to_driver=False)\n",
    "\n",
    "validate_save_restore(PytorchTrainble)\n",
    "validate_save_restore(PytorchTrainble, use_object_store=True)\n",
    "print(\"Success!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Use population-based training with 2 samples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "PBT uses information from the rest of the population to refine the hyperparameters and direct computational resources to models which show promise. \n",
    "\n",
    "In PBT, a worker might copy the model parameters from a better performing worker. It can also explore new hyperparameters by changing the current values randomly (``hyperparam_mutations``).\n",
    "\n",
    "\n",
    "\n",
    "As the training of the population of neural networks progresses, this process of exploiting and exploring is performed periodically, ensuring that all the workers in the population have a good base level of performance and also that new hyperparameters are consistently explored.  This means that PBT can quickly exploit good hyperparameters, can dedicate more training time to promising models and, crucially, can adapt the hyperparameter values throughout training, leading to automatic learning of the best configurations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scheduler = PopulationBasedTraining(\n",
    "    time_attr=\"training_iteration\",\n",
    "    metric=\"mean_accuracy\",\n",
    "    mode=\"max\",\n",
    "    perturbation_interval=5,\n",
    "    hyperparam_mutations={\n",
    "        # distribution for resampling\n",
    "        \"lr\": lambda: np.random.uniform(0.0001, 1),\n",
    "        # allow perturbations within this set of categorical values\n",
    "        \"momentum\": [0.8, 0.9, 0.99],\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "ray.shutdown()  # Restart Ray defensively in case the ray connection is lost. \n",
    "ray.init(log_to_driver=False)\n",
    "\n",
    "\n",
    "analysis = tune.run(\n",
    "    PytorchTrainble,\n",
    "    name=\"pbt_test\",\n",
    "    scheduler=scheduler,\n",
    "    reuse_actors=True,\n",
    "    verbose=1,\n",
    "    stop={\n",
    "        \"training_iteration\": 100,\n",
    "    },\n",
    "    num_samples=4,\n",
    "    \n",
    "    # PBT starts by training many neural networks in parallel with random hyperparameters. \n",
    "    config={\n",
    "        \"lr\": tune.uniform(0.001, 1),\n",
    "        \"momentum\": tune.uniform(0.001, 1),\n",
    "    })\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# You can use this to visualize all mutations of Population-based Training.\n",
    "! cat ~/ray_results/pbt_test/pbt_global.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualizing the results of Population-based Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot by wall-clock time\n",
    "\n",
    "dfs = analysis.fetch_trial_dataframes()\n",
    "# This plots everything on the same plot\n",
    "ax = None\n",
    "for d in dfs.values():\n",
    "    ax = d.plot(\"training_iteration\", \"mean_accuracy\", ax=ax, legend=False)\n",
    "\n",
    "plt.xlabel(\"epoch\"); plt.ylabel(\"Test Accuracy\"); "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}